package com.swapper.math.fsa;

import java.util.*;
import java.util.stream.Collectors;

/**
 * A finite automaton whose transitions are labelled with {@code String} input symbols.
 *
 * <p>The empty string {@code ""} is treated as the ε (epsilon) symbol: it is skipped by
 * {@link #determine()} and printed as {@code ε} by {@link #toString()}. Automata built with the
 * public constructor use state {@code 0} as the initial state.
 */
public final class NFA {
  /** All states of this automaton, held in the container created by {@code State.createGroup()}. */
  private final Set<State> states = State.createGroup();
  /** Every input symbol used by some transition, in first-seen order. */
  private final Set<String> inputs = new LinkedHashSet<>();

  /**
   * Wraps an already-wired set of states (used for the results of {@link #determine()} and
   * {@link #minimize()}). The input alphabet is collected from the states' own transitions,
   * excluding the ε symbol {@code ""}.
   */
  private NFA(Set<State> states) {
    this.states.addAll(states);
    for (State state : states) {
      inputs.addAll(state.getInputs());
    }
    inputs.remove("");
  }

  /**
   * Creates {@code count} states with ids {@code 0..count-1} and no transitions. State {@code 0}
   * is the initial state; every id contained in {@code finallyIds} is an accepting state.
   *
   * @param count number of states to create
   * @param finallyIds ids of the accepting (final) states
   */
  public NFA(int count, Set<Integer> finallyIds) {
    for (int id = 0; id < count; ++id) {
      states.add(new State(id, id == 0, finallyIds.contains(id)));
    }
  }

  /**
   * Looks up a state by id.
   *
   * @param id the state id
   * @return the state with that id
   * @throws IllegalArgumentException if no state has the given id
   */
  public State state(final int id) {
    return states.stream()
            .filter(s -> id == s.getId())
            .findFirst()
            .orElseThrow(() -> new IllegalArgumentException("not find id."));
  }

  /**
   * Adds a transition from state {@code id} on {@code input} to the states whose ids are in
   * {@code targets}. Ids in {@code targets} that do not name an existing state are ignored.
   *
   * @param id source state id
   * @param input input symbol; {@code ""} denotes an ε-transition
   * @param targets ids of the target states
   * @return this automaton, for chaining
   * @throws IllegalArgumentException if {@code input} or {@code targets} is null, or no state has
   *     id {@code id}
   */
  public NFA transition(int id, String input, final Set<Integer> targets) {
    if (input == null) {
      throw new IllegalArgumentException("input is null.");
    }
    if (targets == null) {
      throw new IllegalArgumentException("targets is null.");
    }
    // Resolve the source first so an unknown id does not leave a stray symbol in the alphabet.
    State source = state(id);
    inputs.add(input);
    Set<State> targetStates = states.stream()
            .filter(s -> targets.contains(s.getId()))
            .collect(Collectors.toSet());
    source.translation(input, targetStates);
    return this;
  }

  /**
   * Determinization (subset construction): converts this NFA into an equivalent DFA.
   *
   * <ol>
   *   <li>Create a map M from sets of NFA states to new DFA states.
   *   <li>Take the ε-closure G of the initial state (id 0) as the new initial state S0.
   *   <li>Starting from G, recursively: if G is already a key of M it has been visited;
   *       otherwise create a new state Si (i counts up from 0) and add it to M, explore the
   *       state set reached from G on every non-ε input, and then wire Si's transitions to the
   *       corresponding DFA states.
   *   <li>Build the DFA from the states in M, ordered initial state first, then ordinary
   *       states, then accepting states, each group by id.
   * </ol>
   *
   * <p>Only S0 is marked initial. An input with no reachable target gets an empty transition,
   * so the resulting DFA may be partial.
   *
   * @return the DFA
   * @throws IllegalArgumentException if this automaton has no state with id 0
   */
  public NFA determine() {
    Map<Set<State>, State> map = new LinkedHashMap<>();
    dfaInner(map, state(0).epsilonClosure());
    List<State> ordered = new ArrayList<>(map.values());
    ordered.sort(Comparator.<State>comparingInt(NFA::sortRank).thenComparingInt(State::getId));
    return new NFA(new LinkedHashSet<>(ordered));
  }

  /** Sort rank used by {@link #determine()}: initial first, accepting last, the rest between. */
  private static int sortRank(State state) {
    if (state.isInitial()) {
      return 0;
    }
    return state.isFinally() ? 2 : 1;
  }

  /**
   * Registers {@code group} as a new DFA state (if it is non-empty and not seen yet), then
   * explores its successors depth-first and wires its transitions. DFA ids are assigned in
   * discovery order, so the first registered group (the start closure) gets id 0.
   */
  private void dfaInner(Map<Set<State>, State> map, Set<State> group) {
    if (group.isEmpty() || map.containsKey(group)) {
      return;
    }
    // Only the start closure is the DFA's initial state; a later subset that merely contains
    // the NFA's initial state is not.
    boolean isInitial = map.isEmpty();
    boolean isFinally = group.stream().anyMatch(State::isFinally);
    State state = new State(map.size(), isInitial, isFinally);
    map.put(group, state);
    for (String input : inputs) {
      if (input.isEmpty()) {
        continue; // ε is not an input symbol of the DFA
      }
      Set<State> targets = State.shift(group, input);
      dfaInner(map, targets);
      State target = map.get(targets);
      if (target != null) {
        state.translation(input, Set.of(target));
      } else {
        state.translation(input, Collections.emptySet());
      }
    }
  }

  /**
   * DFA minimization: merges equivalent states by partition refinement.
   *
   * <p>States start out partitioned into accepting and non-accepting blocks (empty blocks are
   * dropped). {@link #minimizeInner(Set)} refines the partition until it is stable, and every
   * block becomes one state of the result. A missing transition (empty target set) stays
   * missing. Intended to be called on a deterministic automaton, e.g. the result of
   * {@link #determine()}.
   *
   * @return the minimized DFA, or {@code this} when every state is already in its own block
   * @throws IllegalStateException if a transition's target set is not contained in a single block
   */
  public NFA minimize() {
    // Initial partition: accepting states vs. the rest.
    Set<State> finals = State.createGroup();
    Set<State> others = State.createGroup();
    for (State state : states) {
      (state.isFinally() ? finals : others).add(state);
    }
    Set<Set<State>> current = new LinkedHashSet<>();
    // Empty blocks would otherwise turn into phantom states of the result.
    if (!finals.isEmpty()) {
      current.add(finals);
    }
    if (!others.isEmpty()) {
      current.add(others);
    }
    minimizeInner(current);
    if (current.size() == states.size()) {
      // Every block is a singleton: no two states are equivalent, nothing to merge.
      return this;
    }
    // One new state per block, numbered in partition order.
    Map<Set<State>, State> map = new LinkedHashMap<>();
    for (Set<State> group : current) {
      boolean isInitial = group.stream().anyMatch(State::isInitial);
      boolean isFinally = group.stream().anyMatch(State::isFinally);
      map.put(group, new State(map.size(), isInitial, isFinally));
    }
    // Wire each merged state to the merged state of the block its members lead to.
    for (Map.Entry<Set<State>, State> entry : map.entrySet()) {
      Set<State> group = entry.getKey();
      State merged = entry.getValue();
      for (String input : inputs) {
        Set<State> targetBlock = blockOf(current, State.shift(group, input));
        if (targetBlock.isEmpty()) {
          merged.translation(input, Collections.emptySet());
        } else {
          merged.translation(input, Set.of(map.get(targetBlock)));
        }
      }
    }
    return new NFA(new LinkedHashSet<>(map.values()));
  }

  /**
   * Refines {@code current} in place until it is stable. Two states stay in the same block
   * only if, for every input, their targets lie in the same block; a missing transition counts
   * as its own distinct target. Blocks are expected to be non-empty and pairwise disjoint.
   *
   * @param current the partition to refine; modified in place
   * @throws IllegalStateException if a state's target set is not contained in a single block
   */
  public void minimizeInner(Set<Set<State>> current) {
    boolean changed;
    do {
      changed = false;
      // Iterate over a snapshot: blocks are replaced in `current` while scanning. Splitting a
      // later block against the already-refined partition is still a valid refinement step.
      for (Set<State> group : new ArrayList<>(current)) {
        if (group.size() < 2) {
          continue;
        }
        Collection<Set<State>> parts = splitGroup(current, group);
        if (parts.size() > 1) {
          current.remove(group);
          current.addAll(parts);
          changed = true;
        }
      }
    } while (changed);
  }

  /**
   * Splits {@code group} by the first input on which its members lead into different blocks of
   * {@code partition}.
   *
   * @return the sub-blocks, or a single-element collection holding {@code group} if it is stable
   */
  private Collection<Set<State>> splitGroup(Set<Set<State>> partition, Set<State> group) {
    for (String input : inputs) {
      Map<Set<State>, Set<State>> byTargetBlock = new LinkedHashMap<>();
      for (State state : group) {
        byTargetBlock
                .computeIfAbsent(blockOf(partition, state.shift(input)), k -> State.createGroup())
                .add(state);
      }
      if (byTargetBlock.size() > 1) {
        return byTargetBlock.values();
      }
    }
    return List.of(group);
  }

  /**
   * Returns the block of {@code partition} that contains all of {@code targets}, or an empty set
   * when {@code targets} is empty (a missing transition).
   *
   * @throws IllegalStateException if no single block contains all of {@code targets}
   */
  private static Set<State> blockOf(Set<Set<State>> partition, Set<State> targets) {
    if (targets.isEmpty()) {
      return Collections.emptySet();
    }
    return partition.stream()
            .filter(block -> block.containsAll(targets))
            .findFirst()
            .orElseThrow(() -> new IllegalStateException("not find shift"));
  }

  /**
   * Renders the transition table. The header row lists the input symbols (ε for {@code ""}),
   * and each following row shows a state and its target set for every input.
   */
  @Override
  public String toString() {
    Table table = new Table(states.size() + 1, inputs.size() + 1);
    table.set(0, 0, "S\\I");
    int col = 0;
    for (String input : inputs) {
      table.set(0, ++col, "".equals(input) ? "ε" : input);
    }
    int row = 1;
    for (State state : states) {
      col = 0;
      table.set(row, col, state.toString());
      for (String input : inputs) {
        table.set(row, ++col, state.shift(input).toString());
      }
      ++row;
    }
    return table.toString();
  }
}
